
Conversation

@thirtiseven (Collaborator) commented Dec 23, 2025

Contributes to #14069
Depends on NVIDIA/spark-rapids-jni#4107

Description

WIP

For discussions and AI review.

This PR adds partial support for from_protobuf, with a limited feature set plus the framework code.

Checklists

  • This PR has added documentation for new or modified features or behaviors.
  • This PR has added new tests or modified existing tests to cover new code paths.
    (Please explain in the PR description how the new code paths are tested, such as names of the new/existing tests that cover them.)
  • Performance testing has been performed and its results are added in the PR description. Or, an issue has been filed with a link in the PR description.

@thirtiseven (Collaborator, Author) commented:

@greptile full review


Copilot AI left a comment


Pull request overview

This PR adds partial GPU support for Spark's from_protobuf function, enabling decoding of binary protobuf data into Spark SQL structs. The implementation is intentionally limited to simple scalar types (boolean, int32, int64, float, double, string) and targets Spark 3.4.0+, where spark-protobuf is available as an optional external module.

Key changes:

  • GPU expression implementation for protobuf decoding with simple types only
  • Reflection-based shim layer to optionally register protobuf expressions when spark-protobuf module is available
  • Build configuration updates to optionally include spark-protobuf JAR for integration testing
  • Python integration tests with custom data generators for protobuf message encoding

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.

Summary per file
File Description
sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala Registers protobuf expression rules in the expression mapping
sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala Implements reflection-based shim for ProtobufDataToCatalyst with GPU fallback rules and validation
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala GPU implementation for from_protobuf decoding with null-intolerant behavior
sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala Utility for parsing protobuf FileDescriptorSet (currently unused)
pom.xml Adds maven property to control spark-protobuf dependency inclusion for Spark 3.4+ profiles
integration_tests/pom.xml Configures maven to copy spark-protobuf JAR to dependency directory for tests
integration_tests/src/main/python/protobuf_test.py Integration tests for from_protobuf with parquet round-trip and null input handling
integration_tests/src/main/python/data_gen.py Protobuf message encoder and test data generator for simple scalar types
integration_tests/run_pyspark_from_build.sh Updates test runner to conditionally include spark-protobuf JAR on classpath


Comment on lines 82 to 84




Copilot AI Dec 23, 2025


Remove the trailing blank lines at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.

Comment on lines 249 to 250



Copilot AI Dec 23, 2025


Remove the trailing blank lines at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.

Comment on lines 83 to 84



Copilot AI Dec 23, 2025


Remove the trailing blank lines at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.

Comment on lines 238 to 239

</project>

Copilot AI Dec 23, 2025


Remove the trailing blank line at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.

Comment on lines +228 to +229



Copilot AI Dec 23, 2025


Remove the trailing blank line at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.

Comment on lines 133 to 136
# if from_protobuf is None:
# pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
# if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
# pytest.skip("spark-protobuf JVM module is not available on the classpath")

Copilot AI Dec 23, 2025


These commented-out lines should either be removed or uncommented if the checks are needed. Leaving commented-out code in production reduces maintainability. If these checks are intentionally disabled for this WIP PR, consider adding a TODO comment explaining why and when they should be re-enabled.

Suggested change:
- # if from_protobuf is None:
- # pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
- # if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
- # pytest.skip("spark-protobuf JVM module is not available on the classpath")
+ if from_protobuf is None:
+     pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
+ if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
+     pytest.skip("spark-protobuf JVM module is not available on the classpath")

message_name = "test.Simple"

# Generate descriptor bytes once using the JVM (no protoc dependency)
desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark))

Copilot AI Dec 23, 2025


This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change:
- desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark))
+ desc_bytes = with_cpu_session(_build_simple_descriptor_set_bytes)

message_name = "test.Simple"

# Generate descriptor bytes once using the JVM (no protoc dependency)
desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark))

Copilot AI Dec 23, 2025


This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change:
- desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark))
+ desc_bytes = with_cpu_session(_build_simple_descriptor_set_bytes)

raise ValueError("Unsupported type for protobuf simple generator: {}".format(spark_type))


class ProtobufSimpleMessageRowGen(DataGen):

Copilot AI Dec 23, 2025


The class 'ProtobufSimpleMessageRowGen' does not override '__eq__', but adds the new attribute _fields.
The class 'ProtobufSimpleMessageRowGen' does not override '__eq__', but adds the new attribute _binary_col_name.

# leaving syntax unset is sufficient/compatible.
try:
    fd = fd.setSyntax("proto2")
except Exception:

Copilot AI Dec 23, 2025


'except' clause does nothing but pass and there is no explanatory comment.

Suggested change:
- except Exception:
+ except Exception:
+     # If setSyntax is unavailable (older protobuf-java), we intentionally leave syntax unset.


greptile-apps bot commented Dec 23, 2025

Greptile Overview

Greptile Summary

This PR adds partial GPU support for Spark's from_protobuf() function, enabling GPU-accelerated decoding of protobuf binary data into Spark structs for simple scalar types (boolean, int32, int64, float, double, string).

Key Changes

Core Implementation:

  • GpuFromProtobuf expression that calls JNI layer (Protobuf.decodeToStruct) for GPU-accelerated decoding
  • Supports simple scalar types with various encodings (varint, zigzag, fixed)
  • Implements schema projection optimization: only decodes fields actually used downstream, filling others with null columns (see the sketch after this list)
  • Proper null handling: input nulls propagate to output
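A minimal Scala sketch of the schema-projection idea above; the names requiredFieldIndices and referenced are hypothetical and not taken from the PR:

import org.apache.spark.sql.types.StructType

// Given the full protobuf-derived schema and the set of field names the parent Project
// actually references, keep only the indices that need to be decoded on the GPU.
def requiredFieldIndices(fullSchema: StructType, referenced: Set[String]): Array[Int] =
  fullSchema.fields.zipWithIndex.collect {
    case (field, idx) if referenced.contains(field.name) => idx
  }

// Fields that are not referenced are never decoded; the output struct is later rebuilt with
// all-null columns in those positions so it still matches the expected Spark schema.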

Shim Layer Integration:

  • ProtobufExprShims registers GPU override rules for ProtobufDataToCatalyst (Spark 3.4.0+)
  • Uses reflection to resolve descriptors via Spark's ProtobufUtils (avoids shaded protobuf class conflicts); a classpath-probe sketch follows this list
  • Sophisticated field analysis: determines which fields are required by analyzing parent Project expressions
  • Falls back to CPU when unsupported field types are required
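Because spark-protobuf is an optional module, the GPU override can only be registered when its classes are actually on the classpath. A hedged sketch of such a reflection-based probe (the fully qualified class name below is an assumption, not verified against the PR):

object SparkProtobufProbe {
  // True only when the optional spark-protobuf jar is present; the plugin can then register
  // the override for ProtobufDataToCatalyst and otherwise skip registration silently.
  lazy val isAvailable: Boolean =
    try {
      Class.forName("org.apache.spark.sql.protobuf.ProtobufDataToCatalyst",
        /* initialize = */ false, Thread.currentThread().getContextClassLoader)
      true
    } catch {
      case _: ClassNotFoundException => false
    }
}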

Test Infrastructure:

  • Maven configuration stages spark-protobuf jar for Spark 3.4.0+ profiles only
  • Bash script conditionally includes protobuf jar based on Spark version
  • Python helper adds jars to driver classpath for Class.forName() availability
  • Comprehensive integration tests with dynamic descriptor generation (no protoc dependency)
  • ProtobufSimpleMessageRowGen generates correctly-encoded test data following the protobuf wire format (the encoding rules are sketched after this list)
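For context, these are the wire-format rules such a generator has to follow, shown as an illustrative Scala sketch (the PR's generator in data_gen.py is Python; none of this code is from the PR):

import scala.collection.mutable.ArrayBuffer

// Base-128 varint: 7 payload bits per byte, least-significant group first,
// high bit set on every byte except the last.
def writeVarint(value: Long, out: ArrayBuffer[Byte]): Unit = {
  var v = value
  while ((v & ~0x7FL) != 0L) {
    out += ((v & 0x7FL) | 0x80L).toByte
    v >>>= 7
  }
  out += (v & 0x7FL).toByte
}

// ZigZag keeps small negative sint32/sint64 values small after varint encoding.
def zigZag64(n: Long): Long = (n << 1) ^ (n >> 63)

// Every field starts with a key: (field number << 3) | wire type, itself a varint.
def writeKey(fieldNumber: Int, wireType: Int, out: ArrayBuffer[Byte]): Unit =
  writeVarint((fieldNumber.toLong << 3) | wireType.toLong, out)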

Architecture Decisions

  1. Schema Projection Optimization: Only decodes fields that are actually used, reducing GPU work for wide schemas
  2. Reflection-Based Descriptor Resolution: Uses Spark's existing ProtobufUtils to avoid shaded class conflicts
  3. Graceful Degradation: Falls back to CPU for unsupported features (nested types, repeated fields, etc.)

Limitations (Expected for "Partial Support")

  • Only simple scalar types supported (no nested messages, arrays, maps)
  • No repeated fields
  • ENUM support requires enums.as.ints=true option
  • Limited to Spark 3.4.0+ (when spark-protobuf was introduced)

Issues Found

Style Issue: StructType matching in schema projection doesn't check nullable flags, which could cause false positive matches when identifying protobuf output references (lines 446-451 in ProtobufExprShims.scala)

Confidence Score: 4/5

  • This PR is safe to merge with minor style improvement recommended. The implementation is well-structured with proper error handling and graceful CPU fallback for unsupported cases.
  • Score of 4 reflects: (1) Solid core implementation with proper resource management and null handling, (2) Comprehensive test coverage including edge cases, (3) Well-designed schema projection optimization, (4) Clean Maven/build integration following existing patterns. Deducted 1 point for minor style issue in StructType matching that could be improved for correctness, though it's unlikely to cause real-world problems given the context.
  • Pay attention to ProtobufExprShims.scala for the struct type matching logic (lines 446-451) - consider adding nullable flag comparison for more precise matching.

Important Files Changed

File Analysis

Filename Score Overview
sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala 4/5 Implements GPU-accelerated protobuf decoding for simple scalar types. Clean implementation with proper error handling and null column generation.
sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala 3/5 Complex schema projection logic for from_protobuf. Has sophisticated field analysis but potential edge case issues in projection detection and type matching.
integration_tests/src/main/python/protobuf_test.py 4/5 Comprehensive integration tests for protobuf decoding including null handling and schema projection. Properly handles missing spark-protobuf module.
integration_tests/src/main/python/data_gen.py 5/5 Adds ProtobufSimpleMessageRowGen with correct protobuf wire format encoding including proper varint and sign-extension handling.

Sequence Diagram

sequenceDiagram
    participant User
    participant Spark SQL
    participant ProtobufDataToCatalyst
    participant ProtobufExprShims
    participant GpuFromProtobuf
    participant JNI
    participant cuDF

    User->>Spark SQL: from_protobuf(col, message_name, desc_path)
    Spark SQL->>ProtobufDataToCatalyst: Create expression
    
    Note over ProtobufExprShims: GPU Override Registration
    ProtobufExprShims->>ProtobufExprShims: tagExprForGpu()
    ProtobufExprShims->>ProtobufExprShims: Resolve descriptor via reflection
    ProtobufExprShims->>ProtobufExprShims: Analyze all fields (types, encoding)
    ProtobufExprShims->>ProtobufExprShims: Detect required fields (schema projection)
    ProtobufExprShims->>ProtobufExprShims: Check all required fields supported
    
    alt All required fields supported
        ProtobufExprShims->>GpuFromProtobuf: convertToGpu()
        GpuFromProtobuf->>JNI: Protobuf.decodeToStruct()
        JNI->>cuDF: Decode protobuf binary to struct
        cuDF-->>JNI: Decoded struct (required fields only)
        JNI-->>GpuFromProtobuf: Column vector
        GpuFromProtobuf->>GpuFromProtobuf: Build full schema struct
        GpuFromProtobuf->>GpuFromProtobuf: Fill non-decoded fields with nulls
        GpuFromProtobuf->>GpuFromProtobuf: Apply input nulls
        GpuFromProtobuf-->>Spark SQL: Result column
    else Unsupported fields required
        ProtobufExprShims->>Spark SQL: Fall back to CPU
        Spark SQL->>ProtobufDataToCatalyst: Execute on CPU
        ProtobufDataToCatalyst-->>Spark SQL: Result column
    end
    
    Spark SQL-->>User: DataFrame with decoded data


@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. integration_tests/src/main/python/protobuf_test.py, lines 133-136

    syntax: commented-out test skips will cause tests to fail if spark-protobuf is not available

9 files reviewed, 1 comment


Signed-off-by: Haoyang Li <[email protected]>

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.



Comment on lines 199 to 201
val (tid, _) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType)
typeIds(idx) = tid
scales(idx) = encoding.get

Copilot AI Dec 25, 2025


The second return value from sparkTypeToCudfId (the encoding) is being ignored and replaced with the encoding determined from the protobuf type. This suggests that sparkTypeToCudfId should not return an encoding at all, or its second return value should be used. Consider refactoring sparkTypeToCudfId to only return the type ID, since the encoding is determined elsewhere based on the protobuf field type.

Suggested change:
- val (tid, _) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType)
- typeIds(idx) = tid
- scales(idx) = encoding.get
+ val (tid, scale) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType)
+ typeIds(idx) = tid
+ if (scale != encoding.get) {
+   willNotWorkOnGpu(
+     s"Encoding mismatch for '${sf.name}': Spark ${sf.dataType} vs " +
+     s"Protobuf $protoTypeName")
+   return
+ }
+ scales(idx) = scale

Comment on lines 78 to 88
def sparkTypeToCudfId(dt: DataType): (Int, Int) = dt match {
  case BooleanType => (DType.BOOL8.getTypeId.getNativeId, ENC_DEFAULT)
  case IntegerType => (DType.INT32.getTypeId.getNativeId, ENC_DEFAULT)
  case LongType => (DType.INT64.getTypeId.getNativeId, ENC_DEFAULT)
  case FloatType => (DType.FLOAT32.getTypeId.getNativeId, ENC_DEFAULT)
  case DoubleType => (DType.FLOAT64.getTypeId.getNativeId, ENC_DEFAULT)
  case StringType => (DType.STRING.getTypeId.getNativeId, ENC_DEFAULT)
  case BinaryType => (DType.LIST.getTypeId.getNativeId, ENC_DEFAULT)
  case other =>
    throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other")
}

Copilot AI Dec 25, 2025


The sparkTypeToCudfId function always returns ENC_DEFAULT as the second tuple element, regardless of the input type. However, this encoding value is not actually used by the caller in ProtobufExprShims.scala (line 199), which discards it with an underscore. Consider simplifying this function to return only the type ID as an Int, since the encoding is determined separately based on the protobuf wire type.
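A sketch of the suggested simplification (not the PR's final code), keeping the same type mappings from the snippet above but dropping the unused encoding element:

import ai.rapids.cudf.DType
import org.apache.spark.sql.types._

def sparkTypeToCudfId(dt: DataType): Int = dt match {
  case BooleanType => DType.BOOL8.getTypeId.getNativeId
  case IntegerType => DType.INT32.getTypeId.getNativeId
  case LongType    => DType.INT64.getTypeId.getNativeId
  case FloatType   => DType.FLOAT32.getTypeId.getNativeId
  case DoubleType  => DType.FLOAT64.getTypeId.getNativeId
  case StringType  => DType.STRING.getTypeId.getNativeId
  case BinaryType  => DType.LIST.getTypeId.getNativeId
  case other =>
    throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other")
}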

Comment on lines 31 to 82
object ProtobufDescriptorUtils {

  def buildMessageDescriptor(
      fileDescriptorSetBytes: Array[Byte],
      messageName: String): Descriptors.Descriptor = {
    val fds = DescriptorProtos.FileDescriptorSet.parseFrom(fileDescriptorSetBytes)
    val protos = fds.getFileList.asScala.toSeq
    val byName = protos.map(p => p.getName -> p).toMap
    val cache = mutable.HashMap.empty[String, Descriptors.FileDescriptor]

    def buildFileDescriptor(name: String): Descriptors.FileDescriptor = {
      cache.getOrElseUpdate(name, {
        val p = byName.getOrElse(name,
          throw new IllegalArgumentException(s"Missing FileDescriptorProto for '$name'"))
        val deps = p.getDependencyList.asScala.map(buildFileDescriptor _).toArray
        Descriptors.FileDescriptor.buildFrom(p, deps)
      })
    }

    val fileDescriptors = protos.map(p => buildFileDescriptor(p.getName))
    val candidates = fileDescriptors.iterator.flatMap(fd => findMessageDescriptors(fd, messageName))
      .toSeq

    candidates match {
      case Seq(d) => d
      case Seq() =>
        throw new IllegalArgumentException(
          s"Message '$messageName' not found in FileDescriptorSet")
      case many =>
        val names = many.map(_.getFullName).distinct.sorted
        throw new IllegalArgumentException(
          s"Message '$messageName' is ambiguous; matches: ${names.mkString(", ")}")
    }
  }

  private def findMessageDescriptors(
      fd: Descriptors.FileDescriptor,
      messageName: String): Iterator[Descriptors.Descriptor] = {
    def matches(d: Descriptors.Descriptor): Boolean = {
      d.getName == messageName ||
        d.getFullName == messageName ||
        d.getFullName.endsWith("." + messageName)
    }

    def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = {
      val nested = d.getNestedTypes.asScala.iterator.flatMap(walk _)
      if (matches(d)) Iterator.single(d) ++ nested else nested
    }

    fd.getMessageTypes.asScala.iterator.flatMap(walk _)
  }
}

Copilot AI Dec 25, 2025


This utility class appears to be unused in the current implementation. The ProtobufExprShims.scala uses Spark's ProtobufUtils via reflection (buildMessageDescriptorWithSparkProtobuf) instead of this custom utility. Consider removing this file if it's not needed for future work, or add documentation explaining its intended purpose if it's meant for upcoming features.

return
# Add driver-class-path for each jar
jar_list = jars.replace(',', ' ').split()
driver_cp = ':'.join(jar_list)

Copilot AI Dec 25, 2025


The classpath separator is hardcoded to colon (:) which is Unix-specific. On Windows, the classpath separator should be semicolon (;). Consider using os.pathsep instead of ':' to make this code platform-independent.

Suggested change:
- driver_cp = ':'.join(jar_list)
+ driver_cp = os.pathsep.join(jar_list)

- one column per message field (Spark scalar types)
- a binary column containing a serialized protobuf message containing those fields
This is intentionally limited to the simple scalar types supported in Patch 1:

Copilot AI Dec 25, 2025


The comment refers to "Patch 1" which may be unclear to readers who are not familiar with the development context. Consider rephrasing to something like "initial implementation" or "this PR" for better clarity.

Suggested change:
- This is intentionally limited to the simple scalar types supported in Patch 1:
+ This is intentionally limited to the simple scalar types currently supported by this implementation:

Comment on lines 236 to 240
private def writeTempDescFile(descBytes: Array[Byte]): String = {
  val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc")
  Files.write(tmp, descBytes)
  tmp.toFile.deleteOnExit()
  tmp.toString

Copilot AI Dec 25, 2025


The temporary descriptor file is marked for deletion on JVM exit using deleteOnExit(), but there's no guarantee this will be called, especially in cases of abrupt termination or long-running processes. Consider using a try-finally block or resource management approach to ensure cleanup, or document that these temp files are expected to persist in some scenarios.
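One possible shape for that cleanup, sketched under the assumption that the descriptor path only needs to live for the duration of a single call (withTempDescFile is a hypothetical name, not from the PR):

import java.nio.file.{Files, Path}

def withTempDescFile[T](descBytes: Array[Byte])(body: String => T): T = {
  val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc")
  try {
    Files.write(tmp, descBytes)
    body(tmp.toString)
  } finally {
    // Eager cleanup instead of relying on deleteOnExit() at JVM shutdown.
    Files.deleteIfExists(tmp)
  }
}

If the path has to outlive the call (for example, because it is captured inside the expression), eager deletion does not apply and the deleteOnExit() approach may remain the pragmatic choice.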

from_protobuf = _try_import_from_protobuf()
if from_protobuf is None:
    pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):

Copilot AI Dec 25, 2025


This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change:
- if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
+ if not with_cpu_session(_spark_protobuf_jvm_available):

Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
@thirtiseven (Collaborator, Author) commented:

@greptileai full review


@greptile-apps greptile-apps bot left a comment


2 files reviewed, 2 comments


Comment on lines 84 to 99
try {
  for (i <- fullSchema.fields.indices) {
    if (decodedIdx < decodedFieldIndices.length && decodedFieldIndices(decodedIdx) == i) {
      // This field was decoded - extract from decoded struct
      fullChildren(i) = decodedStruct.getChildColumnView(decodedIdx).copyToColumnVector()
      decodedIdx += 1
    } else {
      // This field was not decoded - create null column
      fullChildren(i) = GpuFromProtobuf.createNullColumn(
        fullSchema.fields(i).dataType, numRows)
    }
  }
  cudf.ColumnVector.makeStruct(numRows, fullChildren: _*)
} finally {
  fullChildren.foreach(c => if (c != null) c.close())
}


If an exception occurs during the loop (lines 85-95), the columns that were already created in fullChildren will be leaked because the finally block only closes non-null entries, but an exception could leave fullChildren partially populated. The finally block should be inside the loop body or use a more robust pattern.

Additionally, if cudf.ColumnVector.makeStruct on line 96 throws an exception, the columns in fullChildren will be leaked because they've already been closed in the finally block, but makeStruct hasn't taken ownership yet.

The safer pattern would be:

try {
  for (i <- fullSchema.fields.indices) {
    val child = try {
      if (decodedIdx < decodedFieldIndices.length && decodedFieldIndices(decodedIdx) == i) {
        val col = decodedStruct.getChildColumnView(decodedIdx).copyToColumnVector()
        decodedIdx += 1
        col
      } else {
        GpuFromProtobuf.createNullColumn(fullSchema.fields(i).dataType, numRows)
      }
    } catch {
      case e: Throwable =>
        fullChildren.take(i).foreach(c => if (c != null) c.close())
        throw e
    }
    fullChildren(i) = child
  }
  cudf.ColumnVector.makeStruct(numRows, fullChildren: _*)
} finally {
  fullChildren.foreach(c => if (c != null) c.close())
}

Or use withResource pattern more carefully to ensure proper cleanup ordering.

Comment on lines 446 to 451
case st: StructType =>
  // Compare field names and types - StructType equality can be tricky
  st.fields.length == fullSchema.fields.length &&
    st.fields.zip(fullSchema.fields).forall { case (a, b) =>
      a.name == b.name && a.dataType == b.dataType
    }


The StructType comparison doesn't account for nullable flags on struct fields, which could lead to false positive matches. Two structs with the same field names and types but different nullability should not be considered equal for this matching purpose.

Consider:

Suggested change:
- case st: StructType =>
-   // Compare field names and types - StructType equality can be tricky
-   st.fields.length == fullSchema.fields.length &&
-     st.fields.zip(fullSchema.fields).forall { case (a, b) =>
-       a.name == b.name && a.dataType == b.dataType
-     }
+ case st: StructType =>
+   // Compare field names, types, and nullable flags
+   st.fields.length == fullSchema.fields.length &&
+     st.fields.zip(fullSchema.fields).forall { case (a, b) =>
+       a.name == b.name && a.dataType == b.dataType && a.nullable == b.nullable
+     }

Signed-off-by: Haoyang Li <[email protected]>